{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 稠密连接网络（DenseNet）\n",
    "\n",
    "ResNet中的跨层连接设计引申出了数个后续工作。本节我们介绍其中的一个：稠密连接网络（DenseNet）。\n",
    "\n",
    "<div align=center>\n",
    "<img width=\"400\" src=\"../imgs/dense.jpg\"/>\n",
    "</div>\n",
    "<div align=center> ResNet（左）与DenseNet（右）在跨层连接上的主要区别：使用相加和使用连结</div>\n",
    "\n",
    "前后相邻的运算抽象为模块$A$和模块$B$。与ResNet的主要区别在于，DenseNet里模块$B$的输出不是像ResNet那样和模块$A$的输出相加，而是在通道维上连结。\n",
    "\n",
    "DenseNet的主要构建模块是\n",
    "1. 稠密块（dense block）定义了输入和输出是如何连结的\n",
    "2. 过渡层（transition layer）用来控制通道数 1x1卷积 压缩信道"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "#DenseNet使用了ResNet改良版的“批量归一化、激活和卷积”结构，我们首先在`conv_block`函数里实现这个结构。\n",
    "\n",
    "import time\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"..\") \n",
    "torch.backends.cudnn.benchmark = True\n",
    "torch.backends.cudnn.deterministic = False\n",
    "torch.backends.cudnn.enabled = True\n",
    "\n",
    "import dl_utils\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "def conv_block(in_channels, out_channels):\n",
    "    blk = nn.Sequential(nn.BatchNorm2d(in_channels), \n",
    "                        nn.ReLU(),\n",
    "                        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))\n",
    "    return blk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dense Block 由多个 conv block 构成，每次信道翻倍\n",
    "\n",
    "class DenseBlock(nn.Module):\n",
    "    def __init__(self, num_convs, in_channels, out_channels):\n",
    "        super(DenseBlock, self).__init__()\n",
    "        \n",
    "        net = []\n",
    "        for i in range(num_convs):\n",
    "            # 信道加倍\n",
    "            in_c = in_channels + i * out_channels\n",
    "            net.append(conv_block(in_c, out_channels))\n",
    "            \n",
    "        self.net = nn.ModuleList(net)\n",
    "        \n",
    "        self.out_channels = in_channels + num_convs * out_channels # 计算输出通道数\n",
    "\n",
    "    def forward(self, X):\n",
    "        for blk in self.net:\n",
    "            Y = blk(X)\n",
    "            #print(X[0][0][0])\n",
    "            #相当于每次输入，都是带有之前的输出层的全部特征\n",
    "            X = torch.cat((X, Y), dim=1)  # 在通道维上将输入和输出连结\n",
    "        return X"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "<div align=center>\n",
    "<img width=\"600\" src=\"../imgs/densenet1.png\"/>\n",
    "</div>\n",
    "<div align=center>DenseBlock 示意图</div>\n",
    "\n",
    "\n",
    "<div align=center>\n",
    "<img width=\"600\" src=\"../imgs/densenet2.png\"/>\n",
    "</div>\n",
    "<div align=center>DenseNet</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DenseBlock(\n",
      "  (net): ModuleList(\n",
      "    (0): Sequential(\n",
      "      (0): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (1): ReLU()\n",
      "      (2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    )\n",
      "    (1): Sequential(\n",
      "      (0): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (1): ReLU()\n",
      "      (2): Conv2d(48, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    )\n",
      "    (2): Sequential(\n",
      "      (0): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (1): ReLU()\n",
      "      (2): Conv2d(80, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    )\n",
      "    (3): Sequential(\n",
      "      (0): BatchNorm2d(112, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (1): ReLU()\n",
      "      (2): Conv2d(112, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    )\n",
      "  )\n",
      ")\n",
      "torch.Size([1, 144, 32, 32])\n"
     ]
    }
   ],
   "source": [
    "net = DenseBlock(4,16,32)\n",
    "print(net)\n",
    "x = torch.randn(1,16,32,32)\n",
    "x  = net(x)\n",
    "print(x.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 过渡层\n",
    "\n",
    "由于每个稠密块都会带来通道数的增加，使用过多则会带来过于复杂的模型。过渡层用来控制模型复杂度。它通过$1\\times1$卷积层来减小通道数，并使用步幅为2的平均池化层减半高和宽，从而进一步降低模型复杂度。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def transition_block(in_channels, out_channels):\n",
    "    blk = nn.Sequential(\n",
    "            nn.BatchNorm2d(in_channels), \n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(in_channels, out_channels, kernel_size=1),\n",
    "            nn.AvgPool2d(kernel_size=2, stride=2))\n",
    "    return blk\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([4, 23, 8, 8])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 10, 4, 4])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "blk = DenseBlock(2, 3, 10)\n",
    "X = torch.rand(4, 3, 8, 8)\n",
    "Y = blk(X)\n",
    "print(Y.shape) # torch.Size([4, 23, 8, 8])\n",
    "\n",
    "#对上一个例子中稠密块的输出使用通道数为10的过渡层。此时输出的通道数减为10，高和宽均减半。\n",
    "blk = transition_block(23, 10)\n",
    "blk(Y).shape # torch.Size([4, 10, 4, 4])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DenseNet模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = nn.Sequential(\n",
    "        nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),\n",
    "        nn.BatchNorm2d(64), \n",
    "        nn.ReLU(),\n",
    "        nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n",
    "\n",
    "num_channels, growth_rate = 64, 32  # num_channels为当前的通道数\n",
    "num_convs_in_dense_blocks = [4, 4, 4, 4]\n",
    "\n",
    "for i, num_convs in enumerate(num_convs_in_dense_blocks):\n",
    "    \n",
    "    DB = DenseBlock(num_convs, num_channels, growth_rate)\n",
    "    \n",
    "    net.add_module(\"DenseBlosk_%d\" % i, DB)\n",
    "    \n",
    "    # 上一个稠密块的输出通道数\n",
    "    num_channels = DB.out_channels\n",
    "    \n",
    "    # 在稠密块之间加入通道数减半的过渡层\n",
    "    if i != len(num_convs_in_dense_blocks) - 1:\n",
    "        net.add_module(\"transition_block_%d\" % i, transition_block(num_channels, num_channels // 2))\n",
    "        num_channels = num_channels // 2\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "net.add_module(\"BN\", nn.BatchNorm2d(num_channels))\n",
    "net.add_module(\"relu\", nn.ReLU())\n",
    "net.add_module(\"global_avg_pool\", dl_utils.GlobalAvgPool2d()) # GlobalAvgPool2d的输出: (Batch, num_channels, 1, 1)\n",
    "net.add_module(\"fc\", nn.Sequential(dl_utils.FlattenLayer(), nn.Linear(num_channels, 10))) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 8])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.randn(16,3,8,8)\n",
    "x.shape[2:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GlobalAvgPool2d(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(GlobalAvgPool2d, self).__init__()\n",
    "        \n",
    "    def forward(self, x):\n",
    "        return nn.functional.max_pool2d(x, x.shape[2:])\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([16, 3, 8, 8])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 3, 1, 1])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(x.shape)\n",
    "pool = GlobalAvgPool2d()\n",
    "pool(x).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0  output shape:\t torch.Size([1, 64, 48, 48])\n",
      "1  output shape:\t torch.Size([1, 64, 48, 48])\n",
      "2  output shape:\t torch.Size([1, 64, 48, 48])\n",
      "3  output shape:\t torch.Size([1, 64, 24, 24])\n",
      "DenseBlosk_0  output shape:\t torch.Size([1, 192, 24, 24])\n",
      "transition_block_0  output shape:\t torch.Size([1, 96, 12, 12])\n",
      "DenseBlosk_1  output shape:\t torch.Size([1, 224, 12, 12])\n",
      "transition_block_1  output shape:\t torch.Size([1, 112, 6, 6])\n",
      "DenseBlosk_2  output shape:\t torch.Size([1, 240, 6, 6])\n",
      "transition_block_2  output shape:\t torch.Size([1, 120, 3, 3])\n",
      "DenseBlosk_3  output shape:\t torch.Size([1, 248, 3, 3])\n",
      "BN  output shape:\t torch.Size([1, 248, 3, 3])\n",
      "relu  output shape:\t torch.Size([1, 248, 3, 3])\n",
      "global_avg_pool  output shape:\t torch.Size([1, 248, 1, 1])\n",
      "fc  output shape:\t torch.Size([1, 10])\n"
     ]
    }
   ],
   "source": [
    "X = torch.rand((1, 1, 96, 96))\n",
    "for name, layer in net.named_children():\n",
    "    X = layer(X)\n",
    "    print(name, ' output shape:\\t', X.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training on  cuda\n",
      "epoch 1, loss 0.4076, train acc 0.857, test acc 0.879, time 42.9 sec\n",
      "epoch 2, loss 0.1220, train acc 0.911, test acc 0.892, time 38.4 sec\n",
      "epoch 3, loss 0.0702, train acc 0.923, test acc 0.907, time 37.7 sec\n",
      "epoch 4, loss 0.0465, train acc 0.932, test acc 0.905, time 37.5 sec\n",
      "epoch 5, loss 0.0338, train acc 0.938, test acc 0.915, time 37.3 sec\n"
     ]
    }
   ],
   "source": [
    "batch_size = 256\n",
    "# 如出现“out of memory”的报错信息，可减小batch_size或resize\n",
    "train_iter, test_iter = dl_utils.load_data_fashion_mnist(batch_size, resize=96)\n",
    "\n",
    "lr, num_epochs = 0.001, 5\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "dl_utils.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
